#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from .lookup_args import *


# Operator library loading. The Jinja variable `is_fbcode` decides at
# template-render time which of the two sections below ends up in the
# generated module.
{% if is_fbcode %}
# fbcode build: load each operator library by its build-target path so the
# `torch.ops.fb.*` calls made in `invoke` below can be resolved.
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cpu")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:cumem_utils")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
torch.ops.load_library(
    "//deeplearning/fbgemm/fbgemm_gpu:split_table_batched_embeddings"
)
{% else %}
# NOTE(review): in the non-fbcode build this module loads nothing; the
# explicit load of "fbgemm_gpu_py.so" below is disabled. Presumably the
# operators are registered by another import before `invoke` is called --
# TODO confirm.
#import os
#torch.ops.load_library(os.path.join(os.path.join(os.path.dirname(os.path.dirname(__file__)), "fbgemm_gpu_py.so")))
{% endif %}


def invoke(
    common_args: CommonArgs,
    optimizer_args: OptimizerArgs,
    {% if "momentum1_dev" in args.split_function_arg_names %}
    momentum1: Momentum,
    {% endif %}
    {% if "momentum2_dev" in args.split_function_arg_names %}
    momentum2: Momentum,
    {% endif %}
    {% if "iter" in args.split_function_arg_names %}
    iter: int,
    {% endif %}
) -> torch.Tensor:
    if (common_args.host_weights.numel() > 0):
        return torch.ops.fb.split_embedding_codegen_lookup_{{ optimizer }}_function_cpu(
            # common_args
            host_weights=common_args.host_weights,
            weights_placements=common_args.weights_placements,
            weights_offsets=common_args.weights_offsets,
            D_offsets=common_args.D_offsets,
            total_D=common_args.total_D,
            max_D=common_args.max_D,
            hash_size_cumsum=common_args.hash_size_cumsum,
            total_hash_size_bits=common_args.total_hash_size_bits,
            indices=common_args.indices,
            offsets=common_args.offsets,
            pooling_mode=common_args.pooling_mode,
            indice_weights=common_args.indice_weights,
            feature_requires_grad=common_args.feature_requires_grad,
            # optimizer_args
            gradient_clipping = optimizer_args.gradient_clipping,
            max_gradient=optimizer_args.max_gradient,
            stochastic_rounding=optimizer_args.stochastic_rounding,
            {% if "learning_rate" in args.split_function_arg_names %}
            learning_rate=optimizer_args.learning_rate,
            {% endif %}
            {% if "eps" in args.split_function_arg_names %}
            eps=optimizer_args.eps,
            {% endif %}
            {% if "beta1" in args.split_function_arg_names %}
            beta1=optimizer_args.beta1,
            {% endif %}
            {% if "beta2" in args.split_function_arg_names %}
            beta2=optimizer_args.beta2,
            {% endif %}
            {% if "weight_decay" in args.split_function_arg_names %}
            weight_decay=optimizer_args.weight_decay,
            {% endif %}
            {% if "eta" in args.split_function_arg_names %}
            eta=optimizer_args.eta,
            {% endif %}
            {% if "momentum" in args.split_function_arg_names %}
            momentum=optimizer_args.momentum,
            {% endif %}
            # momentum1
            {% if "momentum1_dev" in args.split_function_arg_names %}
            momentum1_host=momentum1.host,
            momentum1_offsets=momentum1.offsets,
            momentum1_placements=momentum1.placements,
            {% endif %}
            # momentum2
            {% if "momentum2_dev" in args.split_function_arg_names %}
            momentum2_host=momentum2.host,
            momentum2_offsets=momentum2.offsets,
            momentum2_placements=momentum2.placements,
            {% endif %}
            # iter
            {% if "iter" in args.split_function_arg_names %}
            iter=iter,
            {% endif %}
        )
    else:
        return torch.ops.fb.split_embedding_codegen_lookup_{{ optimizer }}_function(
            # common_args
            {% if not dense %}
            placeholder_autograd_tensor=common_args.placeholder_autograd_tensor,
            # output_dtype=common_args.output_dtype,
            {% endif %}
            dev_weights=common_args.dev_weights,
            uvm_weights=common_args.uvm_weights,
            lxu_cache_weights=common_args.lxu_cache_weights,
            weights_placements=common_args.weights_placements,
            weights_offsets=common_args.weights_offsets,
            D_offsets=common_args.D_offsets,
            total_D=common_args.total_D,
            max_D=common_args.max_D,
            hash_size_cumsum=common_args.hash_size_cumsum,
            total_hash_size_bits=common_args.total_hash_size_bits,
            indices=common_args.indices,
            offsets=common_args.offsets,
            pooling_mode=common_args.pooling_mode,
            indice_weights=common_args.indice_weights,
            feature_requires_grad=common_args.feature_requires_grad,
            lxu_cache_locations=common_args.lxu_cache_locations,
            # optimizer_args
            gradient_clipping = optimizer_args.gradient_clipping,
            max_gradient=optimizer_args.max_gradient,
            stochastic_rounding=optimizer_args.stochastic_rounding,
            {% if "learning_rate" in args.split_function_arg_names %}
            learning_rate=optimizer_args.learning_rate,
            {% endif %}
            {% if "eps" in args.split_function_arg_names %}
            eps=optimizer_args.eps,
            {% endif %}
            {% if "beta1" in args.split_function_arg_names %}
            beta1=optimizer_args.beta1,
            {% endif %}
            {% if "beta2" in args.split_function_arg_names %}
            beta2=optimizer_args.beta2,
            {% endif %}
            {% if "weight_decay" in args.split_function_arg_names %}
            weight_decay=optimizer_args.weight_decay,
            {% endif %}
            {% if "eta" in args.split_function_arg_names %}
            eta=optimizer_args.eta,
            {% endif %}
            {% if "momentum" in args.split_function_arg_names %}
            momentum=optimizer_args.momentum,
            {% endif %}
            # momentum1
            {% if "momentum1_dev" in args.split_function_arg_names %}
            momentum1_dev=momentum1.dev,
            momentum1_uvm=momentum1.uvm,
            momentum1_offsets=momentum1.offsets,
            momentum1_placements=momentum1.placements,
            {% endif %}
            # momentum2
            {% if "momentum2_dev" in args.split_function_arg_names %}
            momentum2_dev=momentum2.dev,
            momentum2_uvm=momentum2.uvm,
            momentum2_offsets=momentum2.offsets,
            momentum2_placements=momentum2.placements,
            {% endif %}
            # iter
            {% if "iter" in args.split_function_arg_names %}
            iter=iter,
            {% endif %}
        )
